{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 递归神经网络从头开始的实现\n",
    ":label:`sec_rnn_scratch`\n",
    "\n",
    "\n",
    "在本节中，我们将根据 :numref:`sec_rnn`中的描述，\n",
    "从头开始基于循环神经网络实现字符级语言模型。\n",
    "这样的模型将在H.G.威尔斯的时光机器数据集上训练。\n",
    "和前面 :numref:`sec_language_model`中介绍过的一样，\n",
    "我们先读取数据集。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java\n",
    "%load ../utils/timemachine/SeqDataLoader.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@FunctionalInterface\n",
    "public interface TriFunction<T, U, V, W> {\n",
    "    public W apply(T t, U u, V v);\n",
    "}\n",
    "\n",
    "@FunctionalInterface\n",
    "public interface QuadFunction<T, U, V, W, R> {\n",
    "    public R apply(T t, U u, V v, W w);\n",
    "}\n",
    "\n",
    "@FunctionalInterface\n",
    "public interface SimpleFunction<T> {\n",
    "    public T apply();\n",
    "}\n",
    "\n",
    "@FunctionalInterface\n",
    "public interface voidFunction<T> {\n",
    "    public void apply(T t);\n",
    "}\n",
    "\n",
    "@FunctionalInterface\n",
    "public interface voidTwoFunction<T, U> {\n",
    "    public void apply(T t, U u);\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 32;\n",
    "int numSteps = 35;\n",
    "Pair<List<NDList>, Vocab> timeMachine = SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);\n",
    "List<NDList> trainIter = timeMachine.getKey();\n",
    "Vocab vocab = timeMachine.getValue();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 独热编码\n",
    "\n",
    "回想一下，每个标记在 `trainIter`.\n",
    "中都表示为一个数字索引。\n",
    "将这些指数直接输入神经网络可能会使其难以识别\n",
    "学习\n",
    "我们通常将每个标记表示为更具表现力的特征向量。\n",
    "最简单的表示法称为 (*one-hot encoding*),\n",
    "介绍了\n",
    "in :numref:`subsec_classification-problem`.\n",
    "\n",
    "简言之，我们将每个索引映射到一个不同的单位向量：假设词汇表中不同标记的数量为 $N$ (`vocab.length()`) ，标记索引的范围为0到 $N-1$。\n",
    "如果token的索引是整数 $i$，那么我们创建一个长度为 $N$ 的所有0的向量，并将元素的位置 $i$ 设置为1。\n",
    "此向量是原始token的一个热向量。索引为0和2的一个独热向量如下所示。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "manager.create(new int[] {0, 2}).oneHot(vocab.length())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们每次采样的小批量的形状是（批量大小、时间步数）。\n",
    "`oneHot` 函数将这样一个小批量转换为三维数据数组，最后一个维度等于词汇表大小(`vocab.length()`).\n",
    "我们经常转换输入，以便获得一个\n",
    "形状输出\n",
    "（时间步数、批次大小、词汇表大小）。\n",
    "这将使我们\n",
    "更方便\n",
    "循环通过最外层维度\n",
    "用于更新小批量的隐藏状态。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray X = manager.arange(10).reshape(new Shape(2,5));\n",
    "X.transpose().oneHot(28).getShape()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化模型参数\n",
    "\n",
    "接下来，我们初始化的模型参数\n",
    "循环神经网络模型。\n",
    "隐藏单元数 `numHiddens` 是一个可调超参数。\n",
    "在培训语言模型时，\n",
    "输入和输出来自同一词汇表。\n",
    "因此，它们具有相同的维度，\n",
    "这等于词汇量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList getParams(int vocabSize, int numHiddens, Device device) {\n",
    "    int numOutputs = vocabSize;\n",
    "    int numInputs = vocabSize;\n",
    "\n",
    "    // 隐藏层参数\n",
    "    NDArray W_xh = normal(new Shape(numInputs, numHiddens), device);\n",
    "    NDArray W_hh = normal(new Shape(numHiddens, numHiddens), device);\n",
    "    NDArray b_h = manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device);\n",
    "    // 输出层参数\n",
    "    NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);\n",
    "    NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);\n",
    "\n",
    "    // 加上梯度\n",
    "    NDList params = new NDList(W_xh, W_hh, b_h, W_hq, b_q);\n",
    "    for (NDArray param : params) {\n",
    "        param.setRequiresGradient(true);\n",
    "    }\n",
    "    return params;\n",
    "}\n",
    "\n",
    "public static NDArray normal(Shape shape, Device device) {\n",
    "    return manager.randomNormal(0f, 0.01f, shape, DataType.FLOAT32, device);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 循环神经网络模型\n",
    "\n",
    "要定义循环神经网络模型，\n",
    "我们首先需要一个 `initRNNState` '函数\n",
    "在初始化时返回隐藏状态。\n",
    "它返回一个填充为0且形状为（批量大小、隐藏单元数）的数据数组。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList initRNNState(int batchSize, int numHiddens, Device device) {\n",
    "    return new NDList(manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的 `rnn` 函数定义了如何计算隐藏状态和输出\n",
    "在一个时间步。\n",
    "注意\n",
    "循环神经网络模型\n",
    "循环通过`inputs`的最外层维度\n",
    "这样它就可以更新小批量的隐藏状态 `H` ，\n",
    "此外\n",
    "这里的激活函数使用 $\\tanh$ 函数。\n",
    "像\n",
    "描述于 :numref:`sec_mlp`, 该\n",
    "当元素均匀分布时， $\\tanh$ 函数的平均值为0\n",
    "分布在实数上。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<NDArray, NDList> rnn(NDArray inputs, NDList state, NDList params) {\n",
    "    // 输入的形状：（`numSteps`、`batchSize`、`vocabSize`）\n",
    "    NDArray W_xh = params.get(0);\n",
    "    NDArray W_hh = params.get(1);\n",
    "    NDArray b_h = params.get(2);\n",
    "    NDArray W_hq = params.get(3);\n",
    "    NDArray b_q = params.get(4);\n",
    "    NDArray H = state.get(0);\n",
    "\n",
    "    NDList outputs = new NDList();\n",
    "    // 'X'的形状：（'batchSize'，'vocabSize`）\n",
    "    NDArray X, Y;\n",
    "    for (int i = 0; i < inputs.size(0); i++) {\n",
    "        X = inputs.get(i);\n",
    "        H = (X.dot(W_xh).add(H.dot(W_hh)).add(b_h)).tanh();\n",
    "        Y = H.dot(W_hq).add(b_q);\n",
    "        outputs.add(Y);\n",
    "    }\n",
    "    return new Pair<>(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义了所有需要的功能，\n",
    "接下来，我们创建一个类来包装这些函数，并存储从头实现的循环神经网络模型的参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** 从头开始实现的RNN模型 */\n",
    "public class RNNModelScratch {\n",
    "    public int vocabSize;\n",
    "    public int numHiddens;\n",
    "    public NDList params;\n",
    "    public TriFunction<Integer, Integer, Device, NDList> initState;\n",
    "    public TriFunction<NDArray, NDList, NDList, Pair> forwardFn;\n",
    "\n",
    "    public RNNModelScratch(\n",
    "            int vocabSize,\n",
    "            int numHiddens,\n",
    "            Device device,\n",
    "            TriFunction<Integer, Integer, Device, NDList> getParams,\n",
    "            TriFunction<Integer, Integer, Device, NDList> initRNNState,\n",
    "            TriFunction<NDArray, NDList, NDList, Pair> forwardFn) {\n",
    "        this.vocabSize = vocabSize;\n",
    "        this.numHiddens = numHiddens;\n",
    "        this.params = getParams.apply(vocabSize, numHiddens, device);\n",
    "        this.initState = initRNNState;\n",
    "        this.forwardFn = forwardFn;\n",
    "    }\n",
    "\n",
    "    public Pair forward(NDArray X, NDList state) {\n",
    "        X = X.transpose().oneHot(this.vocabSize);\n",
    "        return this.forwardFn.apply(X, state, this.params);\n",
    "    }\n",
    "\n",
    "    public NDList beginState(int batchSize, Device device) {\n",
    "        return this.initState.apply(batchSize, this.numHiddens, device);\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们检查输出是否具有正确的形状，例如，以确保隐藏状态的维度保持不变。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "int numHiddens = 512;\n",
    "TriFunction<Integer, Integer, Device, NDList> getParamsFn = (a, b, c) -> getParams(a, b, c);\n",
    "TriFunction<Integer, Integer, Device, NDList> initRNNStateFn =\n",
    "        (a, b, c) -> initRNNState(a, b, c);\n",
    "TriFunction<NDArray, NDList, NDList, Pair> rnnFn = (a, b, c) -> rnn(a, b, c);\n",
    "\n",
    "NDArray X = manager.arange(10).reshape(new Shape(2, 5));\n",
    "Device device = manager.getDevice();\n",
    "\n",
    "RNNModelScratch net =\n",
    "        new RNNModelScratch(\n",
    "                vocab.length(), numHiddens, device, getParamsFn, initRNNStateFn, rnnFn);\n",
    "NDList state = net.beginState((int) X.getShape().getShape()[0], device);\n",
    "Pair<NDArray, NDList> pairResult = net.forward(X.toDevice(device, false), state);\n",
    "NDArray Y = pairResult.getKey();\n",
    "NDList newState = pairResult.getValue();\n",
    "System.out.println(Y.getShape());\n",
    "System.out.println(newState.get(0).getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以看到输出形状是（时间步数 $\\times$ batch大小，词汇表大小），而隐藏状态形状保持不变，即（批大小，隐藏单元数）。\n",
    "\n",
    "\n",
    "## 预测\n",
    "\n",
    "让我们首先定义预测函数来生成`prefix`之后的新字符，\n",
    "其中的`prefix`是一个用户提供的包含多个字符的字符串。\n",
    "在循环遍历`prefix`中的开始字符时，\n",
    "我们不断地将隐状态传递到下一个时间步，但是不生成任何输出。\n",
    "这被称为*预热*（warm-up）期，\n",
    "因为在此期间模型会自我更新（例如，更新隐状态），\n",
    "但不会进行预测。\n",
    "预热期结束后，隐状态的值通常比刚开始的初始值更适合预测，\n",
    "从而预测字符并输出它们。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** 在 `prefix` 后面生成新字符。 */\n",
    "public static String predictCh8(\n",
    "        String prefix, int numPreds, RNNModelScratch net, Vocab vocab, Device device) {\n",
    "    NDList state = net.beginState(1, device);\n",
    "    List<Integer> outputs = new ArrayList<>();\n",
    "    outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n",
    "    SimpleFunction<NDArray> getInput =\n",
    "            () ->\n",
    "                    manager.create(outputs.get(outputs.size() - 1))\n",
    "                            .toDevice(device, false)\n",
    "                            .reshape(new Shape(1, 1));\n",
    "    for (char c : prefix.substring(1).toCharArray()) { // 预热期\n",
    "        state = (NDList) net.forward(getInput.apply(), state).getValue();\n",
    "        outputs.add(vocab.getIdx(\"\" + c));\n",
    "    }\n",
    "\n",
    "    NDArray y;\n",
    "    for (int i = 0; i < numPreds; i++) {\n",
    "        Pair<NDArray, NDList> pair = net.forward(getInput.apply(), state);\n",
    "        y = pair.getKey();\n",
    "        state = pair.getValue();\n",
    "\n",
    "        outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n",
    "    }\n",
    "    StringBuilder output = new StringBuilder();\n",
    "    for (int i : outputs) {\n",
    "        output.append(vocab.idxToToken.get(i));\n",
    "    }\n",
    "    return output.toString();\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们可以测试 `predict_ch8` 函数。\n",
    "我们将前缀指定为 `time traveller ` ，并让它生成10个附加字符。\n",
    "鉴于我们没有对网络进行培训，\n",
    "它将产生荒谬的预测。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "predictCh8(\"time traveller \", 10, net, vocab, manager.getDevice());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 梯度裁剪\n",
    "\n",
    "对于长度为 $T$ 的序列，\n",
    "我们在一次迭代中计算这些 $T$ 时间步上的梯度，这导致在反向传播期间产生长度为 $\\mathcal{O}(T)$ 的矩阵乘积链。\n",
    "如 :numref:`sec_numerical_stability`, 中所述，它可能导致数值不稳定，例如，当 $T$ 较大时，梯度可能会爆炸或消失。因此，RNN模型通常需要额外的帮助来稳定训练。\n",
    "\n",
    "一般来说，\n",
    "在解决优化问题时，\n",
    "我们对模型参数采取更新步骤，\n",
    "以向量形式说\n",
    "$\\mathbf{x}$,\n",
    "在小批量上的负梯度方向 $\\mathbf{g}$ \n",
    "例如，\n",
    "以 $\\eta > 0$ 作为学习率，\n",
    "在一次迭代中，我们更新\n",
    "$\\mathbf{x}$\n",
    "作为$\\mathbf{x} - \\eta \\mathbf{g}$.\n",
    "让我们进一步假设目标函数 $f$\n",
    "行为良好，例如， *Lipschitz continuous* ，常数为 $L$.\n",
    "就是说，\n",
    "对于任何 $\\mathbf{x}$ 和 $\\mathbf{y}$ 我们都有\n",
    "\n",
    "$$|f(\\mathbf{x}) - f(\\mathbf{y})| \\leq L \\|\\mathbf{x} - \\mathbf{y}\\|.$$\n",
    "\n",
    "在这种情况下，我们可以安全地假设，如果我们将参数向量更新为 $\\eta \\mathbf{g}$, 那么\n",
    "\n",
    "$$|f(\\mathbf{x}) - f(\\mathbf{x} - \\eta\\mathbf{g})| \\leq L \\eta\\|\\mathbf{g}\\|,$$\n",
    "\n",
    "也就是说\n",
    "我们观察到的变化不会超过 $L \\eta \\|\\mathbf{g}\\|$。这既是坏事也是好事。\n",
    "在坏事方面，\n",
    "它限制了进步的速度；\n",
    "而在好事方面，\n",
    "它限制了如果我们朝着错误的方向前进，事情可能会出错的程度。\n",
    "\n",
    "有时梯度可能相当大，优化算法可能无法收敛。我们可以通过降低学习率 $\\eta$. 但是如果我们*很少* 得到大的梯度呢？在这种情况下，这种做法可能显得毫无根据。一种流行的替代方法是通过将梯度 $\\mathbf{g}$ 投影回给定半径的球，例如 $\\theta$ 来剪裁梯度 $\\mathbf{g}$\n",
    "\n",
    "$$\\mathbf{g} \\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}.$$\n",
    "\n",
    "通过这样做，我们知道梯度范数永远不会超过 $\\theta$ ，并且\n",
    "更新的梯度与 $\\mathbf{g}$ 的原始方向完全对齐。\n",
    "它还具有限制任何给定影响的理想副作用\n",
    "minibatch（以及其中的任何给定样本）可以应用于参数向量。这\n",
    "赋予模型一定程度的鲁棒性。渐变剪裁提供\n",
    "快速修复渐变爆炸。虽然它不能完全解决问题，但它是缓解问题的众多技术之一。\n",
    "\n",
    "下面我们定义一个函数来剪裁\n",
    "从头开始实现的模型或由高级API构建的模型。\n",
    "还要注意，我们计算了所有模型参数的梯度范数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** 修剪梯度 */\n",
    "public static void gradClipping(RNNModelScratch net, int theta, NDManager manager) {\n",
    "    double result = 0;\n",
    "    for (NDArray p : net.params) {\n",
    "        NDArray gradient = p.getGradient();\n",
    "        gradient.attach(manager);\n",
    "        result += gradient.pow(2).sum().getFloat();\n",
    "    }\n",
    "    double norm = Math.sqrt(result);\n",
    "    if (norm > theta) {\n",
    "        for (NDArray param : net.params) {\n",
    "            NDArray gradient = param.getGradient();\n",
    "            gradient.muli(theta / norm);\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练\n",
    "\n",
    "在训练模特之前，\n",
    "让我们定义一个函数来在一个历元中训练模型。它与我们在三个地方训练： :numref:`sec_softmax_scratch` 模型的方式不同：\n",
    "\n",
    "1. 顺序数据的不同采样方法（随机采样和顺序分区）将导致隐藏状态初始化的差异。\n",
    "2. 在更新模型参数之前，我们剪裁梯度。这确保了模型不会发散，即使在训练过程中的某个点上坡度增大。\n",
    "3. 我们使用困惑度来评估模型。如 :numref:`subsec_perplexity` 中所述，这确保了不同长度的序列具有可比性。\n",
    "\n",
    "\n",
    "明确地\n",
    "当使用顺序分区时，我们仅在每个历元开始时初始化隐藏状态。\n",
    "由于下一个minibatch中的 $i^\\mathrm{th}$ 子序列示例与当前的 $i^\\mathrm{th}$ 子序列示例相邻，\n",
    "当前批处理结束时的隐藏状态\n",
    "将\n",
    "用于初始化\n",
    "下一个迷你批处理开始时的隐藏状态。\n",
    "这样，\n",
    "序列的历史信息\n",
    "以隐藏状态存储\n",
    "可能溢出\n",
    "一个epoch内相邻的子序列。\n",
    "然而，隐藏状态的计算\n",
    "任何时候都取决于以前的所有小批量\n",
    "在同一epoch，\n",
    "这使得梯度计算复杂化。\n",
    "为了降低计算成本，\n",
    "我们在处理任何小批量之前分离梯度\n",
    "使隐态的梯度计算\n",
    "始终限于\n",
    "一个小批量中的时间步长。\n",
    "\n",
    "在使用随机抽样时，\n",
    "我们需要为每个迭代重新初始化隐藏状态，因为每个示例都是使用随机位置采样的。\n",
    "与 :numref:`sec_softmax_scratch` 中的`trainepoch3`函数相同，\n",
    "`updater` 是一个通用函数\n",
    "以更新模型参数。\n",
    "它可以是从头开始实现的函数，也可以是中的内置优化函数\n",
    "深度学习框架。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** 在一个opoch内训练一个模型。 */\n",
    "public static Pair<Double, Double> trainEpochCh8(\n",
    "        RNNModelScratch net,\n",
    "        List<NDList> trainIter,\n",
    "        Loss loss,\n",
    "        voidTwoFunction<Integer, NDManager> updater,\n",
    "        Device device,\n",
    "        boolean useRandomIter) {\n",
    "    StopWatch watch = new StopWatch();\n",
    "    watch.start();\n",
    "    Accumulator metric = new Accumulator(2); // 训练损失总数\n",
    "    try (NDManager childManager = manager.newSubManager()) {\n",
    "        NDList state = null;\n",
    "        for (NDList pair : trainIter) {\n",
    "            NDArray X = pair.get(0).toDevice(device, true);\n",
    "            X.attach(childManager);\n",
    "            NDArray Y = pair.get(1).toDevice(device, true);\n",
    "            Y.attach(childManager);\n",
    "            if (state == null || useRandomIter) {\n",
    "                // 在第一次迭代或\n",
    "                // 使用随机取样\n",
    "                state = net.beginState((int) X.getShape().getShape()[0], device);\n",
    "            } else {\n",
    "                for (NDArray s : state) {\n",
    "                    s.stopGradient();\n",
    "                }\n",
    "            }\n",
    "            state.attach(childManager);\n",
    "\n",
    "            NDArray y = Y.transpose().reshape(new Shape(-1));\n",
    "            X = X.toDevice(device, false);\n",
    "            y = y.toDevice(device, false);\n",
    "            try (GradientCollector gc = manager.getEngine().newGradientCollector()) {\n",
    "                Pair<NDArray, NDList> pairResult = net.forward(X, state);\n",
    "                NDArray yHat = pairResult.getKey();\n",
    "                state = pairResult.getValue();\n",
    "                NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n",
    "                gc.backward(l);\n",
    "                metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n",
    "            }\n",
    "            gradClipping(net, 1, childManager);\n",
    "            updater.apply(1, childManager); // 因为已经调用了“mean”函数\n",
    "        }\n",
    "    }\n",
    "    return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练功能支持\n",
    "实现了一个RNN模型\n",
    "要么从头开始\n",
    "或者使用高级API。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** 训练一个模型 */\n",
    "public static void trainCh8(\n",
    "        RNNModelScratch net,\n",
    "        List<NDList> trainIter,\n",
    "        Vocab vocab,\n",
    "        int lr,\n",
    "        int numEpochs,\n",
    "        Device device,\n",
    "        boolean useRandomIter) {\n",
    "    SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n",
    "    Animator animator = new Animator();\n",
    "    // 初始化\n",
    "    voidTwoFunction<Integer, NDManager> updater =\n",
    "            (batchSize, subManager) -> Training.sgd(net.params, lr, batchSize, subManager);\n",
    "    Function<String, String> predict = (prefix) -> predictCh8(prefix, 50, net, vocab, device);\n",
    "    // 训练和推理\n",
    "    double ppl = 0.0;\n",
    "    double speed = 0.0;\n",
    "    for (int epoch = 0; epoch < numEpochs; epoch++) {\n",
    "        Pair<Double, Double> pair =\n",
    "                trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter);\n",
    "        ppl = pair.getKey();\n",
    "        speed = pair.getValue();\n",
    "        if ((epoch + 1) % 10 == 0) {\n",
    "            animator.add(epoch + 1, (float) ppl, \"\");\n",
    "            animator.show();\n",
    "        }\n",
    "    }\n",
    "    System.out.format(\n",
    "            \"perplexity: %.1f, %.1f tokens/sec on %s%n\", ppl, speed, device.toString());\n",
    "    System.out.println(predict.apply(\"time traveller\"));\n",
    "    System.out.println(predict.apply(\"traveller\"));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们可以训练循环神经网络模型。\n",
    "因为我们在数据集中只使用10000个标记，所以模型需要更多的时间来更好地收敛"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 500);\n",
    "\n",
    "int lr = 1;\n",
    "trainCh8(net, trainIter, vocab, lr, numEpochs, manager.getDevice(), false);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后\n",
    "让我们检查一下使用随机抽样方法的结果。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainCh8(net, trainIter, vocab, lr, numEpochs, manager.getDevice(), true);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从零开始实现上述循环神经网络模型，\n",
    "虽然有指导意义，但是并不方便。\n",
    "在下一节中，我们将学习如何改进循环神经网络模型。\n",
    "例如，如何使其实现地更容易，且运行速度更快。\n",
    "\n",
    "\n",
    "## 总结\n",
    "\n",
    "* 我们可以训练一个基于循环神经网络的字符级语言模型，根据用户提供的文本的前缀生成后续文本。\n",
    "* 一个简单的循环神经网络语言模型包括输入编码、循环神经网络模型和输出生成。\n",
    "* 循环神经网络模型在训练以前需要初始化状态，不过随机抽样和顺序划分使用初始化方法不同。\n",
    "* 当使用顺序划分时，我们需要分离梯度以减少计算量。\n",
    "* 在进行任何预测之前，模型通过预热期进行自我更新（例如，获得比初始值更好的隐状态）。\n",
    "* 梯度裁剪可以防止梯度爆炸，但不能应对梯度消失。\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 显示一个热编码相当于为每个对象选择不同的嵌入。\n",
    "2. 调整超参数（例如，epoch数、隐藏单元数、小批量时间步数和学习速率）以改善困惑。\n",
    "    * 你能降到多低？\n",
    "    * 用可学习的嵌入替换一个热编码。这会导致更好的性能吗？\n",
    "    * 它在H.G.威尔斯的其他书籍上的效果如何，例如 [*世界大战*](http://www.gutenberg.org/ebooks/36)?\n",
    "3. 修改预测函数，例如使用采样，而不是拾取最可能的下一个字符。\n",
    "    * 会发生什么？\n",
    "    * 使模型偏向更可能的输出，例如，通过从 $q(x_t \\mid x_{t-1}, \\ldots, x_1) \\propto P(x_t \\mid x_{t-1}, \\ldots, x_1)^\\alpha$ for $\\alpha > 1$进行采样。\n",
    "4. 在不剪切梯度的情况下运行本节中的代码。会发生什么？\n",
    "5. 更改顺序分区，使其不会从计算图中分离隐藏状态。运行时间有变化吗？那么困惑呢？\n",
    "6. 用ReLU替换本节中使用的激活功能，并重复本节中的实验。我们还需要梯度剪裁吗？为什么？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
